import torch
import torch.nn as nn

from torch.distributions import constraints

from dpp.flows.base import Flow
from dpp.nn import Hypernet
from dpp.utils import clamp_preserve_gradients


class FixedAffine(Flow):
    """Elementwise affine map y = a * x + b with constant parameters.

    The scale ``a`` is stored as ``log_scale`` (so it stays positive after
    ``exp``), and the shift ``b`` is either a learnable parameter or the
    constant ``0.0``.

    Args:
        scale_init: Initial value of the scale ``a`` (must be positive,
            since its log is stored).
        shift_init: Initial value of the shift ``b`` (only used when
            ``use_shift`` is True).
        use_shift: Whether to include the additive shift term.
        trainable: Whether the parameters receive gradients.
    """
    domain = constraints.real
    codomain = constraints.real

    def __init__(self, scale_init=1.0, shift_init=0.0, use_shift=False, trainable=False):
        super().__init__()
        # Parameterize the scale in log-space so exp(log_scale) > 0 always.
        self.log_scale = nn.Parameter(
            torch.log(torch.tensor([scale_init])), requires_grad=trainable
        )
        if use_shift:
            self.shift = nn.Parameter(
                torch.tensor([shift_init]), requires_grad=trainable
            )
        else:
            # Plain float 0.0 broadcasts in forward/inverse and adds no parameter.
            self.shift = 0.0

    def forward(self, x, **kwargs):
        """Apply y = exp(log_scale) * x + shift; also return log |dy/dx|."""
        scale = torch.exp(self.log_scale)
        y = scale * x + self.shift
        # d y / d x = exp(log_scale), so log|det J| = log_scale (per element).
        return y, self.log_scale.expand(y.shape)

    def inverse(self, y, **kwargs):
        """Apply x = (y - shift) / exp(log_scale); also return log |dx/dy|."""
        x = (y - self.shift) * torch.exp(-self.log_scale)
        return x, (-self.log_scale).expand(x.shape)


class HyperAffine(Flow):
    """Affine transformation y = exp(s) * x + t with hypernet-generated params.

    The log-scale ``s`` and shift ``t`` are produced per-sample by a
    ``Hypernet`` conditioned on the history ``h`` and/or embedding ``emb``.
    The log-scale is clamped (gradient-preserving) to keep exp(s) in a
    numerically safe range.

    Args:
        config: Model configuration; must provide ``use_history`` and
            ``use_embedding`` flags and whatever ``Hypernet`` requires.
        min_clip: Lower bound applied to the generated log-scale.
        max_clip: Upper bound applied to the generated log-scale.
    """
    domain = constraints.real
    codomain = constraints.real

    def __init__(self, config, min_clip=-5.0, max_clip=3.0):
        super().__init__()
        self.use_history(config.use_history)
        self.use_embedding(config.use_embedding)

        self.min_clip = min_clip
        self.max_clip = max_clip
        # Two scalar outputs per sample: [log_scale, shift].
        self.hypernet = Hypernet(config, param_sizes=[1, 1])

    def get_params(self, h, emb):
        """Generate (log_scale, shift) from the conditioning inputs.

        Inputs corresponding to disabled conditioning modes are dropped
        (set to None) before being passed to the hypernet.

        Raises:
            ValueError: If neither history nor embedding conditioning is
                enabled — the affine parameters would be undefined.
                (Previously this path crashed with an opaque
                UnboundLocalError.)
        """
        if not self.using_history:
            h = None
        if not self.using_embedding:
            emb = None
        if not (self.using_history or self.using_embedding):
            raise ValueError(
                "HyperAffine needs at least one of history or embedding "
                "conditioning to generate its parameters."
            )
        log_scale, shift = self.hypernet(h, emb)
        # Clamp log_scale so exp() stays in a safe range while keeping
        # gradients flowing through out-of-range values.
        log_scale = clamp_preserve_gradients(log_scale, self.min_clip, self.max_clip)
        return log_scale.squeeze(-1), shift.squeeze(-1)

    def forward(self, x, h=None, emb=None):
        """Forward transformation y = exp(log_scale) * x + shift.

        Args:
            x: Samples to transform. shape (*)
            h: History for each sample. Shape should match x (except last dim).
                shape (*, history_size)
            emb: Embeddings for each sample. Shape should match x (except last dim).
                shape (*, embedding_size)

        Returns:
            Tuple of (y, log_det_jac), both shaped like x.
        """
        log_scale, shift = self.get_params(h, emb)
        y = torch.exp(log_scale) * x + shift
        # dy/dx = exp(log_scale) elementwise, so log|det J| = log_scale.
        log_det_jac = log_scale.expand(y.shape)
        return y, log_det_jac

    def inverse(self, y, h=None, emb=None):
        """Inverse transformation x = (y - shift) * exp(-log_scale).

        Args:
            y: Samples to invert. shape (*)
            h: History for each sample. Shape should match y (except last dim).
                shape (*, history_size)
            emb: Embeddings for each sample. Shape should match y (except last dim).
                shape (*, embedding_size)

        Returns:
            Tuple of (x, inv_log_det_jac), both shaped like y.
        """
        log_scale, shift = self.get_params(h, emb)
        x = (y - shift) * torch.exp(-log_scale)
        inv_log_det_jac = -log_scale.expand(x.shape)
        return x, inv_log_det_jac
